import os
import json
import requests

from openai import OpenAI
import numpy as np

from diffgro.kinematic_llm.prompt import *
from diffgro.kinematic_llm.parser import parse_kinematic


# Credentials come from the environment; never hard-code them here. Any key
# that has ever been committed to version control must be treated as leaked
# and revoked/rotated.
API_TOKEN = os.getenv("HF_TOKEN", "")

# Auth header and endpoint for the Hugging Face Inference API used by `query`.
headers = {"Authorization": f"Bearer {API_TOKEN}"}
API_URL = "https://api-inference.huggingface.co/models/google/gemma-7b"

# If OPENAI_KEY is unset, api_key=None makes the OpenAI client fall back to the
# standard OPENAI_API_KEY environment variable (it raises if neither is set).
CLIENT = OpenAI(api_key=os.getenv("OPENAI_KEY"))


def call_chatgpt(
    prompt: str, model: str = "gpt-3.5-turbo", temperature: float = 0.0
) -> str:
    """Send ``prompt`` as a single user message and return the reply text.

    Args:
        prompt: Full prompt text, sent as one ``user`` message.
        model: Chat-completions model name.
        temperature: Sampling temperature (lower is more deterministic).

    Returns:
        The content of the first choice, or ``""`` when the API returns no
        text content for it.
    """
    completion = CLIENT.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=temperature,
    )
    # `message.content` is optional in the API response (e.g. refusals or
    # tool calls); normalize to "" so text-parsing callers never get None.
    return completion.choices[0].message.content or ""


def query(payload, timeout=300):
    """POST ``payload`` as JSON to ``API_URL`` and return the decoded JSON reply.

    Args:
        payload: JSON-serializable request body.
        timeout: Seconds to wait for the server, forwarded to ``requests``.
            ``None`` waits indefinitely.

    Returns:
        The response body decoded as UTF-8 JSON. HTTP error statuses are not
        checked here; a non-JSON body raises ``json.JSONDecodeError``.
    """
    # Without a timeout, `requests` can block forever on a stalled connection.
    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    return json.loads(response.content.decode("utf-8"))


def parse_waypoints(data):
    """Parse the waypoint section of an LLM reply.

    Only the non-blank lines after the LAST ``### 3D Manipulation Waypoints:``
    header are considered. Recognized lines produce:

    - ``[MOVE_X]`` / ``[MOVE_Y]`` / ``[MOVE_Z]`` -> ``(name, float(first arg), grab)``
    - ``[PUSH]`` -> ``("PUSH", np.array of tokens 1..3, grab)``
    - ``[PULL]`` -> ``("PULL", np.array of the last 3 tokens, grab)``
    - ``[GRIP]`` -> ``("GRIP", None)`` (note: a 2-tuple)

    ``grab`` is True when the substring ``"grab"`` appears anywhere on the
    line. Lines with any other prefix are ignored.

    Args:
        data: Raw reply text; ``None`` or empty is treated as invalid.

    Returns:
        The list of waypoint tuples (possibly empty if the header is the last
        line), or ``None`` when the header is missing or a recognized line
        cannot be parsed.
    """
    if not data:
        return None

    lines = [line.strip() for line in data.split("\n")]
    lines = [line for line in lines if line]

    header = "### 3D Manipulation Waypoints:"
    try:
        # Index just past the LAST header, so any example the model echoes
        # earlier in its reply is skipped.
        start = len(lines) - lines[::-1].index(header)
    except ValueError:
        # No waypoint section in the reply.
        return None

    waypoints = []
    for line in lines[start:]:
        grab = "grab" in line
        try:
            if line.startswith(("[MOVE_X]", "[MOVE_Y]", "[MOVE_Z]")):
                # "[MOVE_X] ..." -> "MOVE_X"
                waypoints.append((line[1:7], float(line.split()[1]), grab))
            elif line.startswith("[PUSH]"):
                # Unpacking enforces exactly three coordinates.
                x, y, z = map(float, line.split()[1:4])
                waypoints.append(("PUSH", np.array([x, y, z]), grab))
            elif line.startswith("[PULL]"):
                # PULL reads the trailing three tokens, unlike PUSH.
                x, y, z = map(float, line.split()[-3:])
                waypoints.append(("PULL", np.array([x, y, z]), grab))
            elif line.startswith("[GRIP]"):
                waypoints.append(("GRIP", None))
        except (ValueError, IndexError) as e:
            # Missing or non-numeric arguments: reject the whole reply.
            print(e)
            return None

    print(waypoints)
    return waypoints


def plan_waypoints(domain_name, task_name, env, obs_0, context, max_retries=10):
    """Ask the LLM for a manipulation waypoint plan and parse it.

    Builds a few-shot prompt from the demonstrations, task instructions,
    object kinematic knowledge and the guidance/action prompts. Sends it via
    ``call_chatgpt`` and parses the reply with ``parse_waypoints``. Replies
    that yield no waypoints are retried with the same prompt.

    Args:
        domain_name: Domain identifier forwarded to the prompt helpers.
        task_name: Task identifier forwarded to the prompt helpers.
        env: Environment handle forwarded to ``parse_kinematic``.
        obs_0: Observation forwarded to ``parse_kinematic`` (presumably the
            initial observation — confirm against callers).
        context: Optional context; currently it does NOT reach the prompt
            (see the review note in the body).
        max_retries: Maximum number of additional LLM calls after the first
            invalid reply. ``None`` retries without limit.

    Returns:
        The non-empty waypoint list produced by ``parse_waypoints``.

    Raises:
        RuntimeError: If no valid waypoints were obtained within
            ``1 + max_retries`` calls.
    """
    context_prompt = ""
    if context is not None:
        context_prompt = get_context_prompt(context, domain_name, task_name)
        # NOTE(review): the context prompt is built and then immediately
        # discarded, so `context` never reaches the LLM. This looks like an
        # intentional ablation switch; confirm before removing this line.
        context_prompt = ""

    prompt = f"""
# Refer to the Example

{get_demonstrations(task_name)}

--------------------------------------------------------------------------

# Now it is your turn to generate the waypoints for the specified task. 
# You should fill in Abstract Manipulation Sequence and 3D Manipulation Waypoints sections.

{context_prompt}

### Task Instruction: 
    {get_task_instructions(domain_name, task_name)}

### Object Kinematic Knowledge:
    {parse_kinematic(domain_name, task_name, env, obs_0)}

{guidance_prompt}

{action_prompt}
"""
    waypoints = parse_waypoints(call_chatgpt(prompt))

    # The identical prompt is resent on every retry, so a persistently
    # malformed reply would otherwise loop (and incur API cost) forever.
    retries = 0
    while not waypoints:
        if max_retries is not None and retries >= max_retries:
            raise RuntimeError(
                f"No valid waypoints for {domain_name}/{task_name} "
                f"after {retries + 1} attempts"
            )
        retries += 1
        print("Invalid waypoints. Trying again.")
        waypoints = parse_waypoints(call_chatgpt(prompt))

    return waypoints
